# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset factories: build MNIST / ImageNet2012 datasets from config, plus a small random CustomDataset."""

import numpy as np
from .mnist import create_mnist_dataset
from .imagenet2012 import create_imagenet_dataset


class CustomDataset:
    """Tiny in-memory dataset of random samples, useful for smoke tests.

    Holds five samples: ``data`` is a (5, 2) float array and ``label`` a
    (5, 1) float array, both drawn with ``np.random.sample`` (uniform on
    [0.0, 1.0)). Indexing yields ``(self.data[index], self.label[index])``.
    """

    def __init__(self):
        # Draw features before targets so a seeded global RNG always
        # produces the same data/label pairing.
        features = np.random.sample((5, 2))
        targets = np.random.sample((5, 1))
        self.data, self.label = features, targets

    def __getitem__(self, index):
        sample = self.data[index]
        target = self.label[index]
        return sample, target

    def __len__(self):
        # Number of samples equals the number of rows in ``data``.
        return len(self.data)


def create_dataset(config, states="train"):
    """
    Build the dataset described by ``config`` for the requested phase.

    Args:
        config: config object exposing ``TRAIN`` and ``VALID`` sections; each
            section must provide ``dataset_name``, ``data_dir`` and
            ``batch_size``.
        states: exactly ``"train"`` selects ``config.TRAIN``; any other value
            (e.g. ``"eval"``) selects ``config.VALID``.

    Returns:
        The dataset built by ``create_mnist_dataset`` or
        ``create_imagenet_dataset`` from the selected section's ``data_dir``
        and ``batch_size``.

    Raises:
        ValueError: if the selected section's ``dataset_name`` is neither
            ``"mnist"`` nor ``"imagenet2012"``.
    """
    dataset_config = config.TRAIN if states == "train" else config.VALID
    dataset_name = dataset_config.dataset_name

    if dataset_name == "mnist":
        return create_mnist_dataset(dataset_config.data_dir,
                                    dataset_config.batch_size)
    if dataset_name == "imagenet2012":
        return create_imagenet_dataset(dataset_config.data_dir,
                                       dataset_config.batch_size)

    # Fail loudly instead of returning None: a misspelled dataset name is
    # reported here rather than as an obscure error on ``None`` downstream.
    raise ValueError(
        "Unsupported dataset_name {!r}; expected 'mnist' or "
        "'imagenet2012'.".format(dataset_name))
